import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
import os
from tqdm.auto import tqdm
from glob import glob
import cv2
import numpy as np
import pandas as pd
import PIL
import urllib
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from random import uniform
# from imgaug import augmenters as iaa
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
import torch.utils.data as td
import torchvision as tv
from PIL import Image
import matplotlib.pyplot as plt
import time
import pydicom as dcm
# Select the compute device: prefer the GPU when CUDA is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device  # notebook display of the chosen device
# !pip install git+https://github.com/albumentations-team/albumentations
# !pip install --user albumentations==1.1.0
import albumentations as A
import albumentations.pytorch
# Albumentations augmentation pipeline (currently mostly disabled; note that
# `aug` is defined but the datasets below do not appear to use it).
# NOTE(review): the first OneOf has every child transform commented out,
# leaving an empty transform list -- confirm albumentations accepts an empty
# OneOf before enabling this pipeline.
aug = A.Compose([
    # A.HorizontalFlip(p=0.5,always_apply=False),
    A.OneOf([
        # A.Random/Contrast(always_apply=False,p=0.2,limit=[-0.1,0.1]),
        # A.RandomGamma(always_apply=False,p=0.2,gamma_limit=[80,120]),
        # A.RandomBrightness(always_apply=False,p=0.2,limit=[-0.1,0.1])
    ], p=0.3),
    # Small elastic / grid warps, each applied with p=0.2 inside a 30% OneOf.
    A.OneOf([
        A.ElasticTransform(always_apply=False,p=0.2,alpha=10,sigma=6.0,alpha_affine=1.5999999999999996,interpolation=1,border_mode=4,approximate=False),
        A.GridDistortion(always_apply=False,p=0.2,num_steps=1,distort_limit=[-0.1,0.1],interpolation=1,border_mode=4),
        #albumentations.augmentations.transforms.OpticalDistortion(always_apply=False,p=0.4,distort_limit=[-2,2],shift_limit=[-0.5,0.5],interpolation=1,border_mode=4),
    ], p=0.3),
    #albumentations.augmentations.transforms.Cutout(always_apply=False,p=0.5,num_holes=8,max_h_size=50,max_w_size=50)
    #A.ShiftScaleRotate(always_apply=False,p=0.5,shift_limit=[-0.0625,0.0625],scale_limit=[-0.09999999999999998,0.12000000000000009],rotate_limit=[-25,25],interpolation=1,border_mode=4,value=None,mask_value=None),
    # albumentations.augmentations.transforms.Resize(always_apply=True,p=1,height=512,width=512,interpolation=1)
])
from skimage.transform import resize
from skimage.io import imread
import numpy as np
import pydicom
def transform_to_hu(medical_image, image):
    """Convert raw DICOM pixel values to Hounsfield units.

    Uses the dataset's RescaleSlope/RescaleIntercept tags, then clamps
    everything below -1024 HU (the conventional scanner floor for air).
    """
    scaled = image * medical_image.RescaleSlope + medical_image.RescaleIntercept
    return np.maximum(scaled, -1024)
def window_image(image, window_center, window_width):
    """Apply intensity windowing to an image.

    Clamps values into [center - width/2, center + width/2] on a copy,
    leaving the input array untouched.
    """
    windowed = image.copy()
    low = window_center - (window_width / 2)
    high = window_center + (window_width / 2)
    # In-place clamp on the copy (keeps the copy's original dtype).
    windowed[windowed < low] = low
    windowed[windowed > high] = high
    return windowed
def resize_normalize(image):
    """Min-max normalize an image to [0, 1] as float64.

    BUG FIX: the original divided by np.max unconditionally, so a constant
    image (max == min after the shift) produced NaN/inf via division by
    zero. A constant image now maps to all zeros instead.
    """
    image = np.array(image, dtype=np.float64)
    image -= np.min(image)
    peak = np.max(image)
    if peak > 0:
        image /= peak
    return image
def read_dicom(image_medical, window_widht, window_level):
    """Turn a pydicom dataset into a windowed, normalized 3-channel image.

    Pipeline: raw pixels -> Hounsfield units -> intensity windowing at
    (window_level, window_widht) -> min-max normalization -> grayscale
    replicated into 3 channels, shape (H, W, 3).

    Note: the misspelled parameter name `window_widht` is kept on purpose
    for caller compatibility.
    """
    raw_pixels = image_medical.pixel_array
    hu = transform_to_hu(image_medical, raw_pixels)
    windowed = window_image(hu.copy(), window_level, window_widht)
    normalized = resize_normalize(windowed)
    gray = np.expand_dims(normalized, axis=2)                 # (H, W, 1)
    return np.concatenate([gray, gray, gray], axis=2)         # (H, W, 3)
def to_binary(img, lower, upper):
    """Boolean mask of pixels lying in the closed interval [lower, upper]."""
    return np.logical_and(lower <= img, img <= upper)
def mask_binarization(mask, threshold=None):
    """Binarize a probability/score mask at a threshold (default 0.5).

    Args:
        mask: np.ndarray or torch.Tensor; values strictly above `threshold`
            become 1, the rest 0.
        threshold: cut-off value; defaults to 0.5 when None.

    Returns:
        Same container kind as `mask`: a uint8 ndarray, or a tensor of
        zeros/ones with `mask`'s dtype.

    Raises:
        TypeError: for unsupported input types. (BUG FIX: the original fell
        through and raised a confusing UnboundLocalError instead.)
    """
    if threshold is None:
        threshold = 0.5
    if isinstance(mask, np.ndarray):
        return (mask > threshold).astype(np.uint8)
    if isinstance(mask, torch.Tensor):
        zeros = torch.zeros_like(mask)
        ones = torch.ones_like(mask)
        return torch.where(mask > threshold, ones, zeros)
    raise TypeError(
        "mask must be np.ndarray or torch.Tensor, got %s" % type(mask).__name__)
def augment_imgs_and_masks(imgs, masks, rot_factor, scale_factor, trans_factor, flip):
    """Jointly augment images and masks with one random affine transform.

    NOTE(review): this relies on `iaa` (imgaug), but the import at the top
    of the file is commented out -- calling this raises NameError until
    `from imgaug import augmenters as iaa` is restored.

    Args:
        imgs: array indexed as (C, H, W) judging by the shape[1]/shape[2]
            translation math below -- TODO confirm against the caller.
        masks: matching mask array, re-binarized after interpolation.
        rot_factor: max absolute rotation (degrees).
        scale_factor: max relative deviation of the scale from 1.0.
        trans_factor: max translation as a fraction of the image size.
        flip: when True, horizontally flip with probability 0.5.

    Returns:
        (imgs, masks); masks come back float32 after re-binarization.
    """
    # Sample one concrete transform from the configured ranges.
    rot_factor = uniform(-rot_factor, rot_factor)
    ran_alp = uniform(10, 100)  # only used by the commented-out elastic transform
    scale_factor = uniform(1 - scale_factor, 1 + scale_factor)
    trans_factor = [int(imgs.shape[1] * uniform(-trans_factor, trans_factor)),
                    int(imgs.shape[2] * uniform(-trans_factor, trans_factor))]
    seq = iaa.Sequential([
        iaa.Affine(
            translate_px={"x": trans_factor[0], "y": trans_factor[1]},
            scale=(scale_factor, scale_factor),
            rotate=rot_factor
        ),
        # iaa.ElasticTransformation(alpha=ran_alp,sigma=5.0)
    ])
    # Freeze the randomness so imgs and masks get the identical transform.
    seq_det = seq.to_deterministic()
    imgs = seq_det.augment_images(imgs)
    masks = seq_det.augment_images(masks)
    if flip and uniform(0, 1) > 0.5:
        imgs = np.flip(imgs, 2).copy()
        masks = np.flip(masks, 2).copy()
    # Interpolation blurs mask edges; snap values back to {0, 1}.
    masks = mask_binarization(masks).astype(np.float32)
    return imgs, masks
# Data augmentation hyper-parameters consumed by augment_imgs_and_masks().
# All are set to "off" here, so the datasets built below are effectively
# unaugmented even when augmentation=True is passed.
rot_factor = 0      # max rotation in degrees
scale_factor = 0.0  # max relative scale jitter
flip = False        # horizontal flip disabled
trans_factor = 0.0  # max translation fraction
class MyDataset(torch.utils.data.Dataset):
    """Paired image/label dataset handling both DICOM+PNG and .npy pairs.

    Each item is a tuple (image, mask, raw_label):
        image: float32 array of shape (3, 512, 512) -- grayscale replicated
            into 3 channels (assumes the source slices are 512x512 for the
            DICOM branch; .npy inputs keep their stored size -- TODO confirm).
        mask: (2, 512, 512) float array, one binary channel per label value
            (1 and 2).
        raw_label: the resized label image before channel encoding.
    """

    def __init__(self, x_dir, y_dir, augmentation=False):
        super().__init__()
        self.augmentation = augmentation
        self.x_img = x_dir  # sequence of input file paths (.dcm or .npy)
        self.y_img = y_dir  # sequence of label file paths (.png or .npy)

    def __len__(self):
        return len(self.x_img)

    def __getitem__(self, idx):
        x_img = self.x_img[idx]
        y_img = self.y_img[idx]
        # Dispatch on the path's last character: '...dcm'/'...png' vs '...npy'.
        if x_img[-1] == 'm' or y_img[-1] == 'g':
            # DICOM input, windowed at (WW=400, WL=0), channels-first float32.
            x_img = dcm.read_file(x_img)
            x_img = read_dicom(x_img, 400, 0)
            x_img = np.transpose(x_img, (2, 0, 1)).astype(np.float32)
            y_img = imread(y_img)
            # resize() rescales intensities to [0, 1]; *255 restores the
            # integer label values. NOTE(review): interpolated resizing can
            # produce in-between values, so the exact-match binarization
            # below may miss boundary pixels -- confirm the labels survive.
            y_img = resize(y_img, (512, 512)) * 255
            color_im = np.zeros([512, 512, 2])
            for i in range(1, 3):
                color_im[:, :, i - 1] = to_binary(y_img, i * 1.0, i * 1.0)
            color_im = np.transpose(color_im, (2, 0, 1))
        else:
            # .npy input/label pair.
            x_img = np.load(x_img)
            x_img = resize_normalize(x_img)
            y_img = np.load(y_img)
            y_img = resize(y_img, (512, 512))
            color_im = np.zeros([512, 512, 2])
            for i in range(1, 3):
                color_im[:, :, i - 1] = to_binary(y_img, i * 1.0, i * 1.0)
            color_im = np.transpose(color_im, (2, 0, 1))
            # Replicate the single-channel image into 3 channels, CHW float32.
            gray = np.expand_dims(x_img, axis=2)                 # (H, W, 1)
            x_img = np.concatenate([gray, gray, gray], axis=2)   # (H, W, 3)
            x_img = np.transpose(x_img, (2, 0, 1)).astype(np.float32)
        if self.augmentation:
            # BUG FIX: the original computed the augmented pair into
            # `img, mask` but then returned the UN-augmented arrays, so
            # augmentation silently never took effect.
            x_img, color_im = augment_imgs_and_masks(
                x_img, color_im, rot_factor, scale_factor, trans_factor, flip)
        return x_img, color_im, y_img
# Folder listings for inputs (DICOM) and labels.
# NOTE(review): listdir uses "./train/label" while every glob below uses
# "./train/Label" -- on a case-sensitive filesystem one of these is wrong.
Adata_path_folder=sorted(os.listdir("./train/DICOM"))
label_path_folder=sorted(os.listdir("./train/label"))
# Split into train/val/test so that cases (patient folders) do not overlap.
import glob
val_input_files=[]
val_label_files=[]
train_input_files=[]
train_label_files=[]
test_input_files=[]
test_label_files=[]
# Earlier 80/20 split kept for reference:
# for i in range(100):
#     if i>79:
#         val_input_files+=sorted(glob.glob("./train/DICOM/"+data_path_folder[i]+"/*.dcm",recursive=True))
#         val_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.png",recursive=True))
#     else:
#         train_input_files+=sorted(glob.glob("./train/DICOM/"+data_path_folder[i]+"/*.dcm",recursive=True))
#         train_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.png",recursive=True))
# Case-level split over 102 folders:
#   i == 0 and i == 101 : .npy pairs      -> train
#   1..69               : .dcm/.png pairs -> train
#   70..89              : .dcm/.png pairs -> val
#   90..100             : .dcm/.png pairs -> test
for i in range(102):
    if i==0:
        train_input_files+=sorted(glob.glob("./train/DICOM/"+Adata_path_folder[i]+"/*.npy",recursive=True))
        train_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.npy",recursive=True))
    elif i<70:
        train_input_files+=sorted(glob.glob("./train/DICOM/"+Adata_path_folder[i]+"/*.dcm",recursive=True))
        train_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.png",recursive=True))
    elif i<90 :
        val_input_files+=sorted(glob.glob("./train/DICOM/"+Adata_path_folder[i]+"/*.dcm",recursive=True))
        val_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.png",recursive=True))
    elif i==101:
        train_input_files+=sorted(glob.glob("./train/DICOM/"+Adata_path_folder[i]+"/*.npy",recursive=True))
        train_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.npy",recursive=True))
    else:
        test_input_files+=sorted(glob.glob("./train/DICOM/"+Adata_path_folder[i]+"/*.dcm",recursive=True))
        test_label_files+=sorted(glob.glob("./train/Label/"+label_path_folder[i]+"/*.png",recursive=True))
len(train_input_files),len(val_input_files),len(test_input_files)  # notebook display
train_input_files = np.array(train_input_files)
train_label_files = np.array(train_label_files)
val_input_files = np.array(val_input_files)
val_label_files = np.array(val_label_files)
test_input_files = np.array(test_input_files)
test_label_files=np.array(test_label_files)
# Training/validation loaders (batch of 8, shuffled; no augmentation).
train_dataset = MyDataset(train_input_files,train_label_files)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8,shuffle=True)
val_dataset = MyDataset(val_input_files,val_label_files)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=8,shuffle=True)
## Sanity-check that inputs and labels correspond.
images,labels,a = next(iter(train_loader))
print(images.shape)
print(labels.shape)
print(labels[labels>=1])
plt.figure(figsize=(16,18))
plt.subplot(1,4,1)
plt.imshow(images[0][0],cmap='gray')   # first channel of the input image
plt.subplot(1,4,2)
plt.imshow(labels[0][0])               # mask channel for label value 1
plt.subplot(1,4,3)
plt.imshow(labels[0][1])               # mask channel for label value 2
plt.subplot(1,4,4)
plt.imshow(a[0])                       # raw resized label returned as 3rd item
plt.show()
# Notebook-style inspection leftovers (bare expressions; no effect as a script):
np.where(labels[0][0]>=0.1)
# len(labels[0][0]==0)
labels[0][0][306]
def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None):
    """Per-channel Dice coefficient for already-normalized probabilities.

    `input` and `target` must share the same (N, C, ...) shape. Target
    voxels equal to `ignore_index` are zeroed out of both operands, and an
    optional per-channel `weight` scales the intersection term. Returns a
    1-D tensor with one Dice score per channel.
    """
    assert input.size() == target.size(), "'input' and 'target' must have the same shape"
    if ignore_index is not None:
        # Exclude ignored voxels from both prediction and target.
        keep = target.clone().ne_(ignore_index)
        keep.requires_grad = False
        input = input * keep
        target = target * keep
    flat_pred = flatten(input)
    flat_true = flatten(target)
    intersect = (flat_pred * flat_true).sum(-1)
    if weight is not None:
        intersect = weight * intersect
    denominator = (flat_pred + flat_true).sum(-1)
    # Clamp guards against division by zero when a channel is empty in both.
    return 2. * intersect / denominator.clamp(min=epsilon)
def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.

    The shapes are transformed as follows:
        (N, C, D, H, W) -> (C, N * D * H * W)
    """
    channels = tensor.size(1)
    # Move the channel axis to the front, keeping the rest in order.
    order = (1, 0) + tuple(range(2, tensor.dim()))
    channel_first = tensor.permute(order).contiguous()
    return channel_first.view(channels, -1)
class DiceLoss(nn.Module):
    """Dice loss: 1 minus the mean per-channel Dice coefficient.

    Args:
        epsilon: denominator clamp to avoid division by zero.
        weight: optional per-channel weights (list or tensor).
        ignore_index: target value excluded from the computation.
        sigmoid_normalization: sigmoid per channel (multi-label) when True,
            otherwise a channel-wise softmax.
        skip_last_target: drop the last target channel before scoring.
    """

    def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True,
                 skip_last_target=False):
        super(DiceLoss, self).__init__()
        if isinstance(weight, list):
            weight = torch.Tensor(weight)
        self.epsilon = epsilon
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('weight', weight)
        self.ignore_index = ignore_index
        if sigmoid_normalization:
            self.normalization = nn.Sigmoid()
        else:
            self.normalization = nn.Softmax(dim=1)
        # If True, skip the last channel in the target.
        self.skip_last_target = skip_last_target

    def forward(self, input, target):
        # Get probabilities from logits.
        input = self.normalization(input)
        if self.weight is not None:
            # BUG FIX: the original wrapped this in `Variable`, which is
            # never imported in this file (NameError whenever a weight is
            # set) and has been deprecated since PyTorch 0.4. A buffer does
            # not require gradients, so moving it to the input's device is
            # all that is needed.
            weight = self.weight.to(input.device)
        else:
            weight = None
        if self.skip_last_target:
            target = target[:, :-1, ...]
        per_channel_dice = compute_per_channel_dice(
            input, target, epsilon=self.epsilon,
            ignore_index=self.ignore_index, weight=weight)
        # Average the Dice score across all channels/classes.
        return torch.mean(1. - per_channel_dice)
# pip install git+https://github.com/qubvel/segmentation_models.pytorch
import segmentation_models_pytorch as smp
# FPN decoder over an ImageNet-pretrained ResNeXt-101 encoder.
# (A hand-written UNet alternative is defined below but left unused.)
model = smp.FPN( #DeepLabV3
    encoder_name="resnext101_32x8d",  # encoder choice, e.g. mobilenet_v2, efficientnet-b7, timm-res2net101_26w_4s
    encoder_weights="imagenet",       # ImageNet-pretrained encoder initialization
    in_channels=3,
    # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=2, # model output channels (number of classes in your dataset)
)
class UNet(nn.Module):
    """Classic U-Net encoder-decoder for 2-class 2-D segmentation.

    Encoder: four Conv-BN-ReLU stages (128 -> 1024 channels) with 2x max
    pooling; bottleneck at 2048; decoder mirrors the encoder using
    transposed convolutions and skip-connection concatenation; a final 1x1
    conv maps 128 channels to 2 per-pixel logits (no activation applied).
    Expects H and W divisible by 16, otherwise the skip concatenations
    mismatch in shape.
    """
    def __init__(self):
        super(UNet, self).__init__()

        # Conv(3x3) -> BatchNorm -> ReLU building block used throughout.
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]
            cbr = nn.Sequential(*layers)
            return cbr

        # Encoder (contracting path): 3 -> 128 -> 256 -> 512 -> 1024.
        self.enc1_1 = CBR2d(in_channels=3, out_channels=128)
        self.enc1_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.enc2_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc2_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.enc3_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc3_2 = CBR2d(in_channels=512, out_channels=512)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.enc4_1 = CBR2d(in_channels=512, out_channels=1024)
        self.enc4_2 = CBR2d(in_channels=1024, out_channels=1024)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck.
        self.enc5_1 = CBR2d(in_channels=1024, out_channels=2048)
        self.dec5_1 = CBR2d(in_channels=2048, out_channels=1024)
        # Decoder (expanding path): upsample, concat skip, two CBR blocks.
        self.unpool4 = nn.ConvTranspose2d(in_channels=1024, out_channels=1024,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec4_2 = CBR2d(in_channels=2 * 1024, out_channels=1024)
        self.dec4_1 = CBR2d(in_channels=1024, out_channels=512)
        self.unpool3 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec3_2 = CBR2d(in_channels=2 * 512, out_channels=512)
        self.dec3_1 = CBR2d(in_channels=512, out_channels=256)
        self.unpool2 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec2_2 = CBR2d(in_channels=2 * 256, out_channels=256)
        self.dec2_1 = CBR2d(in_channels=256, out_channels=128)
        self.unpool1 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        self.dec1_2 = CBR2d(in_channels=2 * 128, out_channels=128)
        self.dec1_1 = CBR2d(in_channels=128, out_channels=128)
        # 1x1 convolution to per-pixel class logits.
        self.fc = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        # Encoder, saving pre-pool activations for the skip connections.
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)
        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)
        # Bottleneck.
        enc5_1 = self.enc5_1(pool4)
        dec5_1 = self.dec5_1(enc5_1)
        # Decoder: upsample, concatenate the matching encoder skip, refine.
        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)
        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)
        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)
        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)
        x = self.fc(dec1_1)
        return x
# model=UNet()
# # model
# pip install monai
# from monai.losses import DiceCELoss
# criterion = DiceCELoss(to_onehot_y=False, softmax=False)
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
# # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
from sklearn.metrics import confusion_matrix
#mport numpy as np
def compute_iou(y_pred, y_true):
    """Mean IoU over the two classes {0, 1} for binarized masks.

    Args:
        y_pred, y_true: torch tensors whose values are 0 or 1 (both call
            sites binarize via mask_binarization / to_binary first).

    Returns:
        np.float64 mean of the two per-class IoUs; NaN (with a runtime
        warning) when a class is absent from both tensors, matching the
        original sklearn-based behavior (callers apply np.nan_to_num).

    Note: the original built a 2x2 confusion matrix with
    sklearn.metrics.confusion_matrix(labels=[0, 1]); this computes the same
    matrix (rows = truth, cols = prediction) with a single np.bincount pass,
    dropping the sklearn dependency from the hot loop.
    """
    y_pred = y_pred.detach().cpu().numpy().astype(np.int64).ravel()
    y_true = y_true.detach().cpu().numpy().astype(np.int64).ravel()
    # current[i, j] == count of pixels with truth i and prediction j.
    current = np.bincount(y_true * 2 + y_pred, minlength=4).reshape(2, 2)
    intersection = np.diag(current)
    ground_truth_set = current.sum(axis=1)
    predicted_set = current.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection
    IoU = intersection / union.astype(np.float32)
    return np.mean(IoU)
# Total parameter count of the model (notebook display).
sum([param.nelement() for param in model.parameters()])
# #model
# model.load_state_dict(torch.load('model_best_2.pt'))
import torch.optim as optim
# Dice loss on sigmoid probabilities, one channel per class.
criterion = DiceLoss(sigmoid_normalization=True)
# optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-8, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# NOTE(review): scheduler.step() is commented out inside the loop below,
# so this scheduler currently has no effect.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=6)
n_epochs =100
cnt =0  # global optimization-step counter (never reset per epoch)
valid_loss_min = np.Inf # track change in validation loss
# Per-epoch loss accumulators (keep track of training and validation loss).
train_loss = torch.zeros(n_epochs)
valid_loss = torch.zeros(n_epochs)
Iou=0
model.to(device)
# Main training loop: Dice-loss optimization with a quick IoU preview every
# 50 steps and best-checkpoint saving on validation-loss improvement.
for e in range(0, n_epochs):
    ###################
    # train the model #
    ###################
    model.train()
    for data, labels,a in tqdm(train_loader):
        # Move the batch from CPU to the selected device (GPU if available).
        data, labels = data.to(device), labels.to(device)
        # Clear the gradients of all optimized variables.
        optimizer.zero_grad()
        # Forward pass: per-pixel class logits.
        logits = model(data)
        # Calculate the batch loss (DiceLoss applies the sigmoid itself).
        loss = criterion(logits, labels)
        # Backward pass: compute gradient of the loss w.r.t. model parameters.
        loss.backward()
        # Perform a single optimization step (parameter update).
        optimizer.step()
        # Update training loss.
        train_loss[e] += loss.item()
        # NOTE(review): `cnt` is global and never reset, so this preview
        # fires every 50 optimization steps across ALL epochs.
        cnt = cnt+1
        if cnt %50==0:
            # Binarize the predictions and report a quick IoU estimate.
            logits = logits.sigmoid()
            logits = mask_binarization(logits.detach().cpu(), 0.5)
            iou = compute_iou(logits,labels)
            print(iou)
            y=logits[0].detach().cpu().numpy()  # predicted masks (2 channels)
            x=labels[0].detach().cpu().numpy()  # ground-truth masks
            plt.figure(figsize=(16,18))
            plt.subplot(1,5,1)
            plt.imshow(x[0])
            plt.subplot(1,5,2)
            plt.imshow(x[1])
            plt.subplot(1,5,3)
            plt.imshow(y[0])
            plt.subplot(1,5,4)
            plt.imshow(y[1])
            plt.subplot(1,5,5)
            plt.imshow(a[0])  # raw label image returned by the dataset
            plt.show()
    # Average training loss over batches.
    train_loss[e] /= len(train_loader)
    #torch.save(model.state_dict(), 'model_.pt')
    ######################
    # validate the model #
    ######################
    with torch.no_grad():
        model.eval()
        for data, labels,a in tqdm(val_loader):
            data, labels = data.to(device), labels.to(device)
            # Forward pass on the validation batch.
            logits = model(data)
            loss = criterion(logits, labels)
            # Update average validation loss.
            valid_loss[e] += loss.item()
    # Calculate average losses.
    valid_loss[e] /= len(val_loader)
    # scheduler.step(valid_loss[e])  # NOTE(review): disabled; scheduler unused
    # Print training/validation statistics.
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        e, train_loss[e], valid_loss[e]))
    # Save model if validation loss has decreased.
    if valid_loss[e] <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(
            valid_loss_min,
            valid_loss[e]))
        torch.save(model.state_dict(), 'model_best_2.pt')
        valid_loss_min = valid_loss[e]
# Plot loss curves over epochs (train then validation).
plt.plot(train_loss)
plt.plot(valid_loss)
# model.load_state_dict(torch.load('model_.pt'))
# Reload the best checkpoint saved during training before evaluating.
model.load_state_dict(torch.load('model_best_2.pt'))
# Alternative test-set globs kept for reference:
# test_input_files1=sorted(glob.glob("./test/TrainPlusGAN/train/*.npy",recursive=True))
# test_label_files1=sorted(glob.glob("./test/TrainPlusGAN/label/*.npy",recursive=True))
# test_input_files2=sorted(glob.glob("./test/TrainPlusGAN/train/**/*.dcm",recursive=True))
# test_label_files2=sorted(glob.glob("./test/TrainPlusGAN/label/**/*.png",recursive=True))
# test_input_files=test_input_files1+test_input_files2
# test_label_files=test_label_files1+test_label_files2
# test_input_files = np.array(test_input_files)
# test_label_files = np.array(test_label_files)
len(test_input_files)  # notebook display
class TestMyDataset(torch.utils.data.Dataset):
    """Test-time dataset: DICOM inputs with PNG labels only.

    Each item is (image, mask, raw_label):
        image: float32 (3, 512, 512) windowed/normalized DICOM slice
            (assumes 512x512 source slices -- TODO confirm).
        mask: (2, 512, 512) float array, one binary channel per label
            value (1 and 2).
        raw_label: the resized label image scaled back to original values.
    """

    def __init__(self, x_dir, y_dir, augmentation=False):
        super().__init__()
        self.augmentation = augmentation
        self.x_img = x_dir  # DICOM file paths
        self.y_img = y_dir  # PNG label file paths

    def __len__(self):
        return len(self.x_img)

    def __getitem__(self, idx):
        x_img = self.x_img[idx]
        y_img = self.y_img[idx]
        # DICOM input windowed at (WW=400, WL=0), channels-first float32.
        x_img = dcm.read_file(x_img)
        x_img = read_dicom(x_img, 400, 0)
        x_img = np.transpose(x_img, (2, 0, 1)).astype(np.float32)
        y_img = imread(y_img)
        # resize() rescales to [0, 1]; *255 restores the label values.
        y_img = resize(y_img, (512, 512)) * 255
        # One binary channel per label value (1 and 2).
        color_im = np.zeros([512, 512, 2])
        for i in range(1, 3):
            color_im[:, :, i - 1] = to_binary(y_img, i * 1.0, i * 1.0)
        color_im = np.transpose(color_im, (2, 0, 1))
        if self.augmentation:
            # BUG FIX: the original computed the augmented pair but returned
            # the UN-augmented arrays, so augmentation never took effect.
            x_img, color_im = augment_imgs_and_masks(
                x_img, color_im, rot_factor, scale_factor, trans_factor, flip)
        return x_img, color_im, y_img
# Build the held-out test loader (batch_size=1, original order preserved).
test_dataset = TestMyDataset(test_input_files,test_label_files)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1,shuffle=False)
len(test_loader)  # notebook display
# Visual sanity check of one test sample.
images,labels,a = next(iter(test_loader))
print(images.shape)
print(labels.shape)
plt.figure(figsize=(16,18))
plt.subplot(1,4,1)
plt.imshow(images[0][0])   # first channel of the input image
plt.subplot(1,4,2)
plt.imshow(labels[0][0])   # mask channel for label value 1
plt.subplot(1,4,3)
plt.imshow(labels[0][1])   # mask channel for label value 2
plt.subplot(1,4,4)
plt.imshow(a[0])           # raw resized label image
plt.show()
# Test-set evaluation: average rounded percentage IoU over all samples.
cnt =0
Iou=0
model.to(device)
with torch.no_grad():
    model.eval()
    for data, labels,a in tqdm(test_loader):
        data, labels = data.to(device), labels.to(device)
        # Forward pass, then binarize the predicted masks at 0.5.
        logits = model(data)
        logits = logits.sigmoid()
        logits = mask_binarization(logits.detach().cpu(), 0.5)
        iouu = compute_iou(logits,labels)
        iouu=np.round(iouu,3)*100  # convert mean IoU to a percentage
        iouu=np.nan_to_num(iouu)   # classes absent from both masks yield NaN
        Iou+=iouu
        labels=labels[0].detach().cpu().numpy()
        logits=logits[0].detach().cpu().numpy()
        cnt = cnt+1
        # Optional visualization every 200 samples (disabled):
        # if cnt %200==0:
        #     plt.figure(figsize=(16,18))
        #     plt.subplot(1,5,1)
        #     plt.imshow(labels[0])
        #     plt.subplot(1,5,2)
        #     plt.imshow(labels[1])
        #     plt.subplot(1,5,3)
        #     plt.imshow(logits[0])
        #     plt.subplot(1,5,4)
        #     plt.imshow(logits[1])
        #     plt.subplot(1,5,5)
        #     plt.imshow(a[0])
        #     plt.show()
# NOTE(review): `//` floor-divides the accumulated percentage -- a true
# mean (`/`) was probably intended here.
print("Iou:",Iou//len(test_loader))
# NOTE(review): `iou` is the variable from the *training* preview (likely a
# typo for `iouu`); this raises NameError if training was not run first.
print(iou)